function savings_cv_cubic = generate_results_cross_validation_cubic(fold,split,reweight,sample,sample_training,sample_testing,sample_wave,target_wave,tcost,covariate,winter_prevuse,SMC,savedir,fbase)
% GENERATE_RESULTS_CROSS_VALIDATION_CUBIC Estimate an EWM cubic treatment
% rule on a training set and evaluate it on a testing set.
%
% The rule is fitted with CPLEX (two mixed-integer programs, keeping the one
% with the lower objective) on the training sample and evaluated on the
% testing sample. It reports cost savings (tcost>0) or energy reduction
% (tcost==0). Rule coefficients and intermediate data are saved to a .mat
% file in savedir.
%
% Inputs:
% (1)  fold: permutation / fold id (used in the output file name)
% (2)  split: share of observations in the testing set
% (3)  reweight: nonzero = use density ratio to reweight waves; 0 = do not.
%      Wave-specific runs are only implemented with reweight~=0.
% (4)  sample: cross-sectional dataset merged using elec_wave367_cross.csv
%      and propensity scores and renamed as data_estimation_pooled_ps
% (5)  sample_training: training sample. For wave-specific runs, a cell
%      array whose elements {1},{2},{3} are the training samples of waves
%      3, 6, 7.
% (6)  sample_testing: testing sample, same layout as sample_training
% (7)  sample_wave: wave used to fit the rule; 0 = full (pooled) estimation
%      sample, otherwise one of 3, 6, 7
% (8)  target_wave: wave on which the rule is evaluated; 0 = full (pooled)
%      estimation sample, otherwise one of 3, 6, 7. sample_wave and
%      target_wave must be both 0 or both non-zero.
% (9)  tcost: private marginal cost of implementing the program per
%      household; >0 represents cost savings, =0 represents kWh reduction.
%      Negative values are rejected.
% (10) covariate: covariate for the analysis: income, size, vintage,
%      minimum of baseline consumption, maximum of baseline consumption, or
%      standard deviation of consumption
% (11) winter_prevuse: 1 = baseline consumption is the mean of consumption
%      in winter months (Jan and Feb); 0 = mean of specified pre-treatment
%      periods
% (12) SMC: 1 = use social marginal cost; 0 = use retail electricity price
% (13) savedir: directory in which the coefficient .mat file is saved
% (14) fbase: string indicating the propensity score and baseline-months
%      specification; prefixed to the output file name
%
% Output:
% (1) savings_cv_cubic: 1-row table with columns
%     rules        - 'cubic'
%     covariate    - {covariate}
%     percent      - share of test-set households assigned to treatment
%     savings      - mean savings per household on the test set
%     totalsavings - savings * n_test * 12 (further divided by 1000 when
%                    tcost==0)
%     cilb, ciub   - confidence-interval placeholders (currently always 0)

%% Validate inputs and build file-name tags
% Fail fast: unsupported combinations otherwise only crash much later with
% "undefined variable" errors, after the expensive data preparation.
is_pooled        = (sample_wave==0 && target_wave==0);
is_wave_specific = ismember(sample_wave,[3 6 7]) && ismember(target_wave,[3 6 7]);
if ~(is_pooled || is_wave_specific)
    error('generate_results_cross_validation_cubic:invalidWave', ...
        ['sample_wave and target_wave must both be 0 (pooled) or both be ' ...
         'one of 3, 6, 7 (got %g and %g).'], sample_wave, target_wave);
end
if is_wave_specific && reweight==0
    % The unweighted wave-specific path is not implemented.
    % - Solver inputs: it relied on variables that are never defined here
    %   (bsperm, g_wave, X_wave, gn_wave, Xn_wave).
    % - Saved outputs: the save step needs dr, dr_hh, Xu_sample_set, gu_p, ...,
    %   which are produced only by the reweighting branch.
    error('generate_results_cross_validation_cubic:notImplemented', ...
        'Wave-specific cross-validation without reweighting (reweight==0) is not implemented.');
end

switch winter_prevuse
    case 1
        s_winter = '_winter';
    case 0
        % Original note: uses month_since_opower=[-21,-1] as pre-treatment months.
        s_winter = '';
    otherwise
        error('generate_results_cross_validation_cubic:invalidWinterPrevuse', ...
            'winter_prevuse must be 0 or 1.');
end
if tcost>0 % cost savings
    if (SMC)
        s_cost = 'SMC';
    else
        s_cost = 'PMC';
    end
elseif tcost==0 % kWh reduction
    s_cost = 'kwh';
else
    error('generate_results_cross_validation_cubic:invalidTcost', ...
        'tcost must be >= 0 (>0 for cost savings, 0 for kWh reduction).');
end

%% Call CPLEX
% Only add the hard-coded install path when it exists; on other machines
% CPLEX is expected to already be on the MATLAB path.
cplex_dir = 'C:\Progra~1\IBM\ILOG\CPLEX_Studio129\cplex\matlab\x64_win64';
if exist(cplex_dir,'dir')
    addpath(cplex_dir);
end

% Setup CPLEX optimization options
opt = cplexoptimset('cplex');
opt.parallel = 1;
opt.threads = 4;
opt.simplex.tolerances.feasibility = 1e-08;
opt.mip.tolerances.integrality = 1e-08;
opt.mip.strategy.variableselect = 3;
opt.mip.strategy.nodeselect = 2;
opt.mip.strategy.lbheur = 1;
opt.mip.limits.cutpasses = -1;
opt.display = 'off';

%% Gather inputs
% Full (pooled) sample: provides k = size(X,2) and, for the reweighting
% branch, the pooled g and X.
wave=0;
[X,Y,D,n,ps,Yscale,x1_wave,prevuse_wave,Xscale] = generate_inputs_cubic_rule_reweight(sample,sample,wave,tcost,covariate,SMC);

if is_pooled
    [X_t,Y_t,D_t,n_t,ps_t,Yscale_t,x1_t,prevuse_t,Xscale_t] = generate_inputs_cubic_rule_reweight(sample_training,sample,wave,tcost,covariate,SMC);
    [X_h,Y_h,D_h,n_h,ps_h,Yscale_h,x1_h,prevuse_h,Xscale_h] = generate_inputs_cubic_rule_reweight(sample_testing,sample,wave,tcost,covariate,SMC);
    scaled_att_h= (sum(D_h.*Y_h.*Yscale_h-((1-D_h).*Y_h.*Yscale_h.*ps_h)./(1-ps_h))/sum(D_h))*sum(D_h)/n_h;
    scaled_att_t= (sum(D_t.*Y_t.*Yscale_t-((1-D_t).*Y_t.*Yscale_t.*ps_t)./(1-ps_t))/sum(D_t))*sum(D_t)/n_t;

else % wave-specific
    % NOTE(review): unlike the pooled branch, the wave-specific scaled ATTs
    % below are NOT multiplied by Yscale. Confirm this is intended.
    wave=3;
    [X3_t,Y3_t,D3_t,n3_t,ps3_t,Yscale3_t,x1_wave3_t,prevuse_wave3_t,Xscale3_t] = generate_inputs_cubic_rule_reweight(sample_training{1},sample,wave,tcost,covariate,SMC);
    [X3_h,Y3_h,D3_h,n3_h,ps3_h,Yscale3_h,x1_wave3_h,prevuse_wave3_h,Xscale3_h] = generate_inputs_cubic_rule_reweight(sample_testing{1},sample,wave,tcost,covariate,SMC);
    scaled_att_wave3_h= (sum(D3_h.*Y3_h-((1-D3_h).*Y3_h.*ps3_h)./(1-ps3_h))/sum(D3_h))*sum(D3_h)/n3_h;
    scaled_att_wave3_t= (sum(D3_t.*Y3_t-((1-D3_t).*Y3_t.*ps3_t)./(1-ps3_t))/sum(D3_t))*sum(D3_t)/n3_t;

    wave=6;
    [X6_t,Y6_t,D6_t,n6_t,ps6_t,Yscale6_t,x1_wave6_t,prevuse_wave6_t,Xscale6_t] = generate_inputs_cubic_rule_reweight(sample_training{2},sample,wave,tcost,covariate,SMC);
    [X6_h,Y6_h,D6_h,n6_h,ps6_h,Yscale6_h,x1_wave6_h,prevuse_wave6_h,Xscale6_h] = generate_inputs_cubic_rule_reweight(sample_testing{2},sample,wave,tcost,covariate,SMC);
    scaled_att_wave6_h= (sum(D6_h.*Y6_h-((1-D6_h).*Y6_h.*ps6_h)./(1-ps6_h))/sum(D6_h))*sum(D6_h)/n6_h;
    scaled_att_wave6_t= (sum(D6_t.*Y6_t-((1-D6_t).*Y6_t.*ps6_t)./(1-ps6_t))/sum(D6_t))*sum(D6_t)/n6_t;

    wave=7;
    [X7_t,Y7_t,D7_t,n7_t,ps7_t,Yscale7_t,x1_wave7_t,prevuse_wave7_t,Xscale7_t] = generate_inputs_cubic_rule_reweight(sample_training{3},sample,wave,tcost,covariate,SMC);
    [X7_h,Y7_h,D7_h,n7_h,ps7_h,Yscale7_h,x1_wave7_h,prevuse_wave7_h,Xscale7_h] = generate_inputs_cubic_rule_reweight(sample_testing{3},sample,wave,tcost,covariate,SMC);
    scaled_att_wave7_h= (sum(D7_h.*Y7_h-((1-D7_h).*Y7_h.*ps7_h)./(1-ps7_h))/sum(D7_h))*sum(D7_h)/n7_h;
    scaled_att_wave7_t= (sum(D7_t.*Y7_t-((1-D7_t).*Y7_t.*ps7_t)./(1-ps7_t))/sum(D7_t))*sum(D7_t)/n7_t;

    % Stack all waves, training sets first, then testing sets.
    % opower_byset holds the wave id of each row.
    % set_byset is 1 for training rows and 0 for testing rows.
    X_byset=[X3_t;X6_t;X7_t;X3_h;X6_h;X7_h];
    opower_byset=[repelem(3,size(X3_t,1))';repelem(6,size(X6_t,1))';repelem(7,size(X7_t,1))';...
        repelem(3,size(X3_h,1))';repelem(6,size(X6_h,1))';repelem(7,size(X7_h,1))'];
    Y_byset=[Y3_t;Y6_t;Y7_t;Y3_h;Y6_h;Y7_h];
    D_byset=[D3_t;D6_t;D7_t;D3_h;D6_h;D7_h];
    set_byset=[repelem(1,size(X3_t,1))';repelem(1,size(X6_t,1))';repelem(1,size(X7_t,1))';...
        repelem(0,size(X3_h,1))';repelem(0,size(X6_h,1))';repelem(0,size(X7_h,1))'];
end

%% Estimate g using IPW
% Each element is the IPW contribution to W(G) - W(G_0).
g=Y.*(D./ps - (1-D)./(1-ps));

if is_pooled
    g_t = Y_t.*(D_t./ps_t - (1-D_t)./(1-ps_t)); % training
    g_h = Y_h.*(D_h./ps_h - (1-D_h)./(1-ps_h)); % testing
else
    g3_t = Y3_t.*(D3_t./ps3_t - (1-D3_t)./(1-ps3_t)); % training
    g6_t = Y6_t.*(D6_t./ps6_t - (1-D6_t)./(1-ps6_t));
    g7_t = Y7_t.*(D7_t./ps7_t - (1-D7_t)./(1-ps7_t));
    g3_h = Y3_h.*(D3_h./ps3_h - (1-D3_h)./(1-ps3_h)); % testing
    g6_h = Y6_h.*(D6_h./ps6_h - (1-D6_h)./(1-ps6_h));
    g7_h = Y7_h.*(D7_h./ps7_h - (1-D7_h)./(1-ps7_h));
    g_byset = [g3_t;g6_t;g7_t;g3_h;g6_h;g7_h];
end

%% Define the target (evaluation) set
% X and g represent the pooled sample. Select the testing data of the
% target wave for evaluation.
switch target_wave
    case 0
        X_target_set=X_h;
        g_target_set=g_h;
        Yscale_target_set=Yscale_h;
        n_target_set=n_h;
        x1_target_set=x1_h;
        prevuse_target_set=prevuse_h;
        Xscale_target_set=Xscale_h;
        scaled_att_target_set=scaled_att_h;
    case 3
        X_target_set=X3_h;
        g_target_set=g3_h;
        Yscale_target_set=Yscale3_h;
        n_target_set=n3_h;
        x1_target_set=x1_wave3_h;
        prevuse_target_set=prevuse_wave3_h;
        Xscale_target_set=Xscale3_h;
        scaled_att_target_set=scaled_att_wave3_h;
    case 6
        X_target_set=X6_h;
        g_target_set=g6_h;
        Yscale_target_set=Yscale6_h;
        n_target_set=n6_h;
        x1_target_set=x1_wave6_h;
        prevuse_target_set=prevuse_wave6_h;
        Xscale_target_set=Xscale6_h;
        scaled_att_target_set=scaled_att_wave6_h;
    case 7
        X_target_set=X7_h;
        g_target_set=g7_h;
        Yscale_target_set=Yscale7_h;
        n_target_set=n7_h;
        x1_target_set=x1_wave7_h;
        prevuse_target_set=prevuse_wave7_h;
        Xscale_target_set=Xscale7_h;
        scaled_att_target_set=scaled_att_wave7_h;
end

%% Solve for rule using cplex

% number of variables (including the constant)
k = size(X,2);
bsi=0;

tic1= tic;
% Generate inputs for the cplex solver
if is_pooled
    bsperm = (1:n_t)'; % identity permutation (no bootstrap resampling)
    gn_t = g_t(bsperm,:);
    Xn_t = X_t(bsperm,:);
    [f,f_n,Aineq_d,Aineq_i,bineq,lb,ub,ctype,Ind] = generate_cplex_inputs(bsperm,g_t,X_t,bsi,k,gn_t,Xn_t);
else % wave-specific with reweighting (reweight==0 rejected above)
    [f,f_n,Aineq_d,Aineq_i,bineq,lb,ub,ctype,dr_hh,dr,Xu_sample_set,gu_p,gr3_t,gr6_t,gr7_t,gr3_h,gr6_h,gr7_h,...
    Xr3_t,Xr6_t,Xr7_t,Xr3_h,Xr6_h,Xr7_h,Yr3_t,Yr6_t,Yr7_t,Yr3_h,Yr6_h,Yr7_h,Dr3_t,Dr6_t,Dr7_t,Dr3_h,Dr6_h,Dr7_h,Ind_byset] = ...
    generate_cplex_inputs_weight_cross_validation(g,X,k,g_byset,X_byset,Y_byset,D_byset,opower_byset,set_byset,sample_wave,target_wave);
end
% Cost minimization with treatment decreasing in pre-treatment usage
[sol_pd, v_pd] = cplexmilp(f,Aineq_d,bineq,[],[],[],[],[],lb,ub,ctype,[],opt);
% ... and with treatment increasing in pre-treatment usage
[sol_pi, v_pi] = cplexmilp(f,Aineq_i,bineq,[],[],[],[],[],lb,ub,ctype,[],opt);

% Keep the rule coefficients beta of the program with the lower objective
if (v_pd < v_pi)
    beta = sol_pd(1:k,:);
    v    = v_pd;
else
    beta = sol_pi(1:k,:);
    v    = v_pi;
end
% Treatment assignment on the target (testing) set
in_Ghat=(X_target_set*beta>0);

fprintf('cplex solution takes %.2f sec\n',toc(tic1))

%% Select the sample (training) set data to save
if sample_wave==0
    X_sample_set=X_t;
    g_sample_set=g_t;
    Yscale_sample_set=Yscale_t;
    n_sample_set=n_t;
    x1_sample_set=x1_t;
    prevuse_sample_set=prevuse_t;
    Xscale_sample_set=Xscale_t;
    scaled_att_sample_set=scaled_att_t;
elseif sample_wave==3
    X_sample_set=X3_t;
    Xr_sample_set=Xr3_t;
    gr_sample_set=gr3_t;
    g_sample_set=g3_t;
    Yscale_sample_set=Yscale3_t;
    n_sample_set=n3_t;
    x1_sample_set=x1_wave3_t;
    prevuse_sample_set=prevuse_wave3_t;
    Xscale_sample_set=Xscale3_t;
    scaled_att_sample_set=scaled_att_wave3_t;
    % dr_hh is sorted by X, so the reordered Yr3_t/Dr3_t are used here
    % instead of Y3_t/D3_t.
    % NOTE(review): ps3_t is not reordered the same way. Confirm alignment.
    scaled_att_sample_set_weighted=(sum(Dr3_t.*Yr3_t.*Yscale3_t.*dr_hh-((1-Dr3_t).*Yr3_t.*Yscale3_t.*dr_hh.*ps3_t)./(1-ps3_t))/sum(Dr3_t))*sum(Dr3_t)/n3_t;
elseif sample_wave==6
    X_sample_set=X6_t;
    Xr_sample_set=Xr6_t;
    gr_sample_set=gr6_t;
    g_sample_set=g6_t;
    Yscale_sample_set=Yscale6_t;
    n_sample_set=n6_t;
    x1_sample_set=x1_wave6_t;
    prevuse_sample_set=prevuse_wave6_t;
    Xscale_sample_set=Xscale6_t;
    scaled_att_sample_set=scaled_att_wave6_t;
    % dr_hh is sorted by X, so the reordered Yr6_t/Dr6_t are used here
    % instead of Y6_t/D6_t.
    % NOTE(review): ps6_t is not reordered the same way. Confirm alignment.
    scaled_att_sample_set_weighted=(sum(Dr6_t.*Yr6_t.*Yscale6_t.*dr_hh-((1-Dr6_t).*Yr6_t.*Yscale6_t.*dr_hh.*ps6_t)./(1-ps6_t))/sum(Dr6_t))*sum(Dr6_t)/n6_t;
elseif sample_wave==7
    X_sample_set=X7_t;
    Xr_sample_set=Xr7_t;
    gr_sample_set=gr7_t;
    g_sample_set=g7_t;
    Yscale_sample_set=Yscale7_t;
    n_sample_set=n7_t;
    x1_sample_set=x1_wave7_t;
    prevuse_sample_set=prevuse_wave7_t;
    Xscale_sample_set=Xscale7_t;
    scaled_att_sample_set=scaled_att_wave7_t;
    % dr_hh is sorted by X, so the reordered Yr7_t/Dr7_t are used here
    % instead of Y7_t/D7_t.
    % NOTE(review): ps7_t is not reordered the same way. Confirm alignment.
    scaled_att_sample_set_weighted=(sum(Dr7_t.*Yr7_t.*Yscale7_t.*dr_hh-((1-Dr7_t).*Yr7_t.*Yscale7_t.*dr_hh.*ps7_t)./(1-ps7_t))/sum(Dr7_t))*sum(Dr7_t)/n7_t;
end

%% Export rule parameters for graphs
s_wave=sprintf('%d%d',sample_wave,target_wave);
s_fold=sprintf('%d',fold);
s_split=sprintf('split%.1f',split);

if is_wave_specific
    % reweight~=0 is guaranteed by the validation at the top.
    filename_coefs=sprintf('coef_cv_cubic_%s_%s%s_%s_%s_fold%s_reweight.mat',covariate,s_cost,s_winter,s_wave,s_split,s_fold);
    filename = sprintf('%s_%s',fbase,filename_coefs);
    save(fullfile(savedir,filename),'beta',...
    'in_Ghat','x1_target_set','prevuse_target_set','g_target_set','covariate','tcost','Yscale_target_set','n_target_set','X_target_set','Xscale_target_set',...
    'dr','dr_hh','Xu_sample_set','X_sample_set','gu_p','g_sample_set','Yscale_sample_set','Xscale_sample_set','n_sample_set','gr_sample_set',...
    'Xr_sample_set','x1_sample_set','prevuse_sample_set','scaled_att_target_set','scaled_att_sample_set','scaled_att_sample_set_weighted');
else % pooled
    filename_coefs=sprintf('coef_cv_cubic_%s_%s%s_%s_%s_perm%s.mat',covariate,s_cost,s_winter,s_wave,s_split,s_fold);
    filename = sprintf('%s_%s',fbase,filename_coefs);
    save(fullfile(savedir,filename),'beta',...
    'in_Ghat','x1_target_set','prevuse_target_set','g_target_set','covariate','tcost','Yscale_target_set','n_target_set','X_target_set','Xscale_target_set',...
    'X_sample_set','g_sample_set','Yscale_sample_set','Xscale_sample_set','n_sample_set',...
    'x1_sample_set','prevuse_sample_set','scaled_att_target_set','scaled_att_sample_set','split','s_fold');
end

%% Output for table
percent=mean(in_Ghat);
% Mean welfare gain of the estimated rule on the target set, ignoring NaNs
savings=mean(g_target_set.*in_Ghat,'omitnan')*Yscale_target_set;

if (tcost)
    total_savings=savings*n_target_set*12;
else
    total_savings=savings*n_target_set*12/1000;
end
% Confidence intervals are not computed in this function (placeholders)
ci_lb=0;
ci_ub=0;

savings_cv_cubic=nan(1,7);
savings_cv_cubic=array2table(savings_cv_cubic,'VariableNames',{'rules','covariate','percent','savings','totalsavings','cilb','ciub'});

savings_cv_cubic.rules='cubic';
savings_cv_cubic.covariate = {covariate};
savings_cv_cubic.percent=percent;
savings_cv_cubic.savings=savings;
savings_cv_cubic.totalsavings=total_savings;
savings_cv_cubic.cilb=ci_lb;
savings_cv_cubic.ciub=ci_ub;
end






